import torch, torch.nn.functional as F
from tree_builder import Uext_batch_from_tree_lists_HMH
from losses import loss_diversity_from_S, loss_reconstruction_from_treeG

def split_batch_to_graphs(batch):
    """Split a collated batch into parallel per-graph lists.

    Args:
        batch: Object exposing ``to_data_list()``. Each returned item must
            provide ``x``, ``edge_index``, ``y`` and ``num_nodes``
            (presumably a torch_geometric ``Batch`` -- confirm with caller).

    Returns:
        Tuple ``(X_list, edge_index_list, y_list)`` with one entry per graph:
        node features, the graph's ``edge_index`` (passed through untouched),
        and the label as a 0-dim ``torch.long`` tensor.
    """
    X_list, edge_index_list, y_list = [], [], []
    for data in batch.to_data_list():
        x = data.x
        if x is None or x.numel() == 0:
            # Featureless graph: give every node a constant scalar feature.
            # Create it on the same device as the graph's edge_index so the
            # per-graph tensors stay co-located (falls back to the default
            # device when edge_index carries no device information).
            fallback_device = getattr(data.edge_index, "device", None)
            x = torch.ones(
                data.num_nodes, 1, dtype=torch.float32, device=fallback_device
            )
        X_list.append(x)
        edge_index_list.append(data.edge_index)
        # Only the first label element is used; one label per graph.
        y_list.append(data.y.view(-1)[0].long())
    return X_list, edge_index_list, y_list

def run_one_epoch(loader, model, opt=None, train=True, device="cpu", levels=4,
                  w_ce=0.8, w_div=0.1, w_rec=0.0):
    """Run one pass over ``loader`` and return ``(avg_loss, accuracy)``.

    Every batch is split into individual graphs, the tree builder is run on
    the whole batch, and the model is then applied graph by graph. Node
    logits are mean-pooled into a single graph-level prediction. In training
    mode the optimizer is stepped once per graph, not once per batch.

    Args:
        loader: Iterable of batches accepted by ``split_batch_to_graphs``.
        model: Callable as ``model(U, feats, tree)`` returning per-node
            logits; must expose ``train()`` and ``eval()``.
        opt: Optimizer; required when ``train`` is true.
        train: When true, update the model; otherwise only evaluate.
        device: Device forwarded to the tree builder.
        levels: Number of hierarchy levels forwarded to the tree builder.
        w_ce: Weight of the cross-entropy term.
        w_div: Weight of the diversity term; the term is skipped when 0.
        w_rec: Weight of the reconstruction term. Defaults to 0.0 because
            the term was historically computed but never added to the total
            loss; a non-zero value enables it.

    Returns:
        Tuple ``(avg_loss, acc)`` averaged over the graphs seen; both are
        ``0.0`` when the loader yields no graphs.

    Raises:
        ValueError: If ``train`` is true and ``opt`` is ``None``.
    """
    if train and opt is None:
        raise ValueError("run_one_epoch(train=True) requires an optimizer (opt)")

    if train:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    total_correct = 0
    total_graphs = 0

    # No backward pass happens during evaluation, so autograd bookkeeping is
    # disabled there to save memory and time.
    with torch.set_grad_enabled(bool(train)):
        for batch in loader:
            X_list, edge_index_list, y_list = split_batch_to_graphs(batch)
            # Only U, feats, tree and S are consumed below; the remaining
            # builder outputs are unpacked to keep the 7-tuple contract.
            (U_batch, _eidx_batch, _n_nodes_batch, _n_edges_batch,
             feats_batch, tree_batch, S_batch) = Uext_batch_from_tree_lists_HMH(
                X_list, edge_index_list,
                levels=levels, ratio=0.8,
                lam=0.1, k_feat=4, k_diff=4, t_heat=0.6, cheb_order=25,
                alpha=(0.5, 0.5), device=device,
                assign_method="sinkhorn", tau=0.9, sinkhorn_iters=10, seed=42,
            )
            per_graph = zip(U_batch, feats_batch, tree_batch, S_batch, y_list)
            for U, feats, tree, S, y in per_graph:
                logits_nodes = model(U, feats, tree)
                # Mean-pool node logits into one row of graph-level logits.
                logits_graph = logits_nodes.mean(dim=0, keepdim=True)
                out_device = logits_graph.device
                yi = y.reshape(1).to(device=out_device, dtype=torch.long)

                loss = w_ce * F.cross_entropy(logits_graph, yi)
                if w_div:
                    loss = loss + w_div * loss_diversity_from_S(
                        S, device=out_device
                    )
                if w_rec:
                    loss = loss + w_rec * loss_reconstruction_from_treeG(
                        tree, device=out_device
                    )

                if train:
                    opt.zero_grad()
                    loss.backward()
                    opt.step()

                total_loss += loss.item()
                pred = logits_graph.argmax(dim=1)
                total_correct += int((pred == yi).sum().item())
                total_graphs += 1

    # Guard against an empty loader.
    denom = max(total_graphs, 1)
    return total_loss / denom, total_correct / denom

def train_model(model, train_loader, val_loader, test_loader, device="cpu", epochs=150, levels=4):
    """Train ``model`` with Adam and report test accuracy at the best validation epoch.

    Each epoch runs one training pass followed by evaluation passes over the
    validation and test loaders, then prints a one-line summary. The test
    accuracy is recorded whenever validation accuracy strictly improves, so
    ties keep the earliest epoch.

    Args:
        model: Module to optimize; its parameters are handed to Adam
            (lr=3e-3, weight_decay=1e-4).
        train_loader: Batches used for optimization.
        val_loader: Batches used for model selection.
        test_loader: Batches used for reporting only.
        device: Forwarded to ``run_one_epoch``.
        epochs: Number of epochs to run.
        levels: Forwarded to ``run_one_epoch``.

    Returns:
        Tuple ``(best_val, best_test_at_val)``: the best validation accuracy
        and the test accuracy measured in that same epoch. With
        ``epochs <= 0`` the initial values ``(-1.0, 0.0)`` are returned.
    """
    opt = torch.optim.Adam(model.parameters(), lr=3e-3, weight_decay=1e-4)
    best_val, best_test_at_val = -1.0, 0.0
    for epoch in range(1, epochs + 1):
        # Only accuracies are reported; the per-epoch losses are discarded.
        _, tr_acc = run_one_epoch(train_loader, model, opt, True, device, levels)
        _, va_acc = run_one_epoch(val_loader, model, None, False, device, levels)
        _, te_acc = run_one_epoch(test_loader, model, None, False, device, levels)
        if va_acc > best_val:
            best_val, best_test_at_val = va_acc, te_acc
        print(f"Epoch {epoch:03d} | train acc {tr_acc:.3f} | val acc {va_acc:.3f} | test acc {te_acc:.3f} | best@val {best_test_at_val:.3f}")
    return best_val, best_test_at_val

